{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "documented-senator",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from model.resnet_sgd_cosineannealing_inference import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "international-possibility",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size that input images are resized to\n",
    "height, width = 448, 448\n",
    "\n",
    "# Training-loop settings\n",
    "epochs = 5\n",
    "batch_size = 4  # used for both the train and test sets\n",
    "\n",
    "# Optimiser and learning-rate schedule hyperparameters\n",
    "lr = 0.01\n",
    "momentum = 0.9\n",
    "# T_0 was sized as train_dataset_size / batch_size, e.g. 899 / 4 ~= 225.\n",
    "# NOTE(review): the run recorded in this notebook reports 300 training images\n",
    "# (75 batches of 4) and its logged lr restarts every 3 epochs, i.e. 225 / 75,\n",
    "# so 225 looks like a 3-epoch cycle for this dataset - confirm that is intended.\n",
    "T_0 = 225\n",
    "T_mult = 1\n",
    "\n",
    "# Depth of the ResNet network; pick one of [18, 34, 50, 101, 152]\n",
    "num_layers = 18\n",
    "pretrained_weights = True\n",
    "\n",
    "# Default 'False' unfreezes the last layer only for tuning.\n",
    "# NOTE(review): this is the string 'False', not the bool; a non-empty string is\n",
    "# truthy, so confirm the model code compares against the string value.\n",
    "unfreeze_all_layers = 'False'\n",
    "\n",
    "train_augmentation = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "metric-governor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both are left unset (None) for this run; they are later handed to the\n",
    "# training code as args.s3_bucket and args.warm_restart respectively.\n",
    "bucket = saved_model_path = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "monthly-richards",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"training_data/dogs/good_dogs_vs_not\"\n",
    "model_dir = \"models\"\n",
    "\n",
    "\n",
    "# Attribute container passed to create_datasets() and train(). Every right-hand\n",
    "# side is evaluated while the class body runs, so same-named assignments such\n",
    "# as `lr = lr` copy the current notebook-level value onto the class.\n",
    "class _args:\n",
    "    # Data locations: one sub-folder of data_dir per split\n",
    "    train, validation, test = (\n",
    "        os.path.join(data_dir, split) for split in (\"train\", \"validation\", \"test\")\n",
    "    )\n",
    "    model_dir = model_dir\n",
    "\n",
    "    # Image size\n",
    "    image_height, image_width = height, width\n",
    "\n",
    "    # Training loop, optimiser and scheduler\n",
    "    epochs = epochs\n",
    "    batch_size = batch_size\n",
    "    lr = lr\n",
    "    momentum = momentum\n",
    "    T_0 = T_0\n",
    "    T_mult = T_mult\n",
    "\n",
    "    # Network configuration\n",
    "    num_layers = num_layers\n",
    "    pretrained_weights = pretrained_weights\n",
    "    unfreeze_all_layers = unfreeze_all_layers\n",
    "    train_augmentation = train_augmentation\n",
    "\n",
    "    # Notebook names that are exposed under a different attribute name\n",
    "    s3_bucket = bucket\n",
    "    warm_restart = saved_model_path\n",
    "\n",
    "\n",
    "args = _args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ecological-middle",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the datasets object that train() consumes; create_datasets comes from\n",
    "# the star import in the first cell and receives every setting through args.\n",
    "datasets = create_datasets(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "enormous-character",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "----------\n",
      "train Loss: 2.2175; train Acc: 0.5633;\n",
      "validation Loss: 0.7623; validation Acc: 0.7700;\n",
      "epoch: 1; lr: 0.0075;\n",
      "\n",
      "Epoch 2/5\n",
      "----------\n",
      "train Loss: 0.9440; train Acc: 0.7233;\n",
      "validation Loss: 1.5797; validation Acc: 0.6800;\n",
      "epoch: 2; lr: 0.0025000000000000014;\n",
      "\n",
      "Epoch 3/5\n",
      "----------\n",
      "train Loss: 1.1325; train Acc: 0.7033;\n",
      "validation Loss: 0.6557; validation Acc: 0.8400;\n",
      "epoch: 3; lr: 0.01;\n",
      "\n",
      "Epoch 4/5\n",
      "----------\n",
      "train Loss: 2.8620; train Acc: 0.6367;\n",
      "validation Loss: 4.6921; validation Acc: 0.5800;\n",
      "epoch: 4; lr: 0.0075;\n",
      "\n",
      "Epoch 5/5\n",
      "----------\n",
      "train Loss: 2.2068; train Acc: 0.6567;\n",
      "validation Loss: 1.4133; validation Acc: 0.7500;\n",
      "epoch: 5; lr: 0.0025000000000000014;\n",
      "\n",
      "Training complete in 9m 53s\n",
      "Best validation Acc: 0.840000\n",
      "models\\model.pth\n",
      "\n",
      "Evaluating best weights:\n",
      "--------------------\n",
      "train Loss: 0.2943 Acc: 0.8933\n",
      "train Avg. F1 Score: 0.893;\n",
      "classification_report: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    bad_dogs       0.93      0.85      0.89       150\n",
      "   good_dogs       0.86      0.93      0.90       150\n",
      "\n",
      "    accuracy                           0.89       300\n",
      "   macro avg       0.90      0.89      0.89       300\n",
      "weighted avg       0.90      0.89      0.89       300\n",
      ";\n",
      "\n",
      "validation Loss: 0.6557 Acc: 0.8400\n",
      "validation Avg. F1 Score: 0.838;\n",
      "classification_report: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    bad_dogs       0.93      0.74      0.82        50\n",
      "   good_dogs       0.78      0.94      0.85        50\n",
      "\n",
      "    accuracy                           0.84       100\n",
      "   macro avg       0.85      0.84      0.84       100\n",
      "weighted avg       0.85      0.84      0.84       100\n",
      ";\n",
      "\n",
      "test Loss: 0.3459 Acc: 0.8700\n",
      "test Avg. F1 Score: 0.870;\n",
      "classification_report: \n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    bad_dogs       0.88      0.86      0.87        50\n",
      "   good_dogs       0.86      0.88      0.87        50\n",
      "\n",
      "    accuracy                           0.87       100\n",
      "   macro avg       0.87      0.87      0.87       100\n",
      "weighted avg       0.87      0.87      0.87       100\n",
      ";\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run training with the settings in args; per-epoch loss/accuracy and a final\n",
    "# train/validation/test evaluation of the best weights are printed below.\n",
    "train(args, datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bizarre-concrete",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
